[CK] Fix MoE 2-stage dispatch for non-128-divisible inter_dim #3973
Open
jonahbernard wants to merge 1 commit into
Motivation
The gfx950 heuristic dispatch sent every shape with inter_dim > 192 to the NPerBlock/KPerBlock=128 fast path. That path fails CK's N % NPerBlock (stage 1) and K % KPerBlock (stage 2) divisibility checks when inter_dim is not a multiple of 128 (e.g. DiffusionGemma moe_inter = 704 = 64 * 11). This PR routes those shapes to the PerBlock=64 instances, whose tile size divides any multiple of 64.
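To make the arithmetic concrete, here is a minimal sketch of the old and widened selection logic checked against inter_dim = 704. The names `TileChoice`, `select_old`, `select_new`, `tile_size`, and `tile_divides` are hypothetical and do not come from the CK sources. The sketch also assumes, as a simplification, that inter_dim is itself the dimension checked against NPerBlock (stage 1) and KPerBlock (stage 2).

```cpp
// Hypothetical names for illustration only; not identifiers from the PR or CK.
enum class TileChoice { PerBlock64, PerBlock128 };

// Old heuristic: only inter_dim <= 192 is sent to the 64-wide instances.
constexpr TileChoice select_old(int inter_dim) {
    return (inter_dim <= 192) ? TileChoice::PerBlock64 : TileChoice::PerBlock128;
}

// Widened heuristic: anything that is not a multiple of 128 also goes to 64.
constexpr TileChoice select_new(int inter_dim) {
    return (inter_dim <= 192 || inter_dim % 128 != 0) ? TileChoice::PerBlock64
                                                      : TileChoice::PerBlock128;
}

constexpr int tile_size(TileChoice c) {
    return c == TileChoice::PerBlock64 ? 64 : 128;
}

// Simplified stand-in for the N % NPerBlock / K % KPerBlock checks,
// assuming inter_dim is the dimension being tiled.
constexpr bool tile_divides(int inter_dim, TileChoice c) {
    return inter_dim % tile_size(c) == 0;
}

// 704 = 5 * 128 + 64, so the 128-wide tile leaves a remainder of 64.
static_assert(704 % 128 == 64, "704 is not a multiple of 128");
static_assert(704 % 64 == 0, "704 is a multiple of 64");

// Old dispatch picks the 128 path for 704 and the divisibility check fails.
static_assert(select_old(704) == TileChoice::PerBlock128, "old: 128 path");
static_assert(!tile_divides(704, select_old(704)), "old: check fails");

// Widened dispatch picks the 64 path for 704 and the check passes.
static_assert(select_new(704) == TileChoice::PerBlock64, "new: 64 path");
static_assert(tile_divides(704, select_new(704)), "new: check passes");

// Multiples of 128 above 192 keep the 128 fast path under both heuristics.
static_assert(select_new(256) == TileChoice::PerBlock128, "256 stays on 128");
```

The checks are compile-time only, so the snippet has no runtime behavior; it only shows that the widened condition changes the choice for shapes such as 704 while leaving multiples of 128 on the fast path.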
Technical Details
Widen the pre-existing dispatch heuristic from if (inter_dim <= 192) to if (inter_dim <= 192 || inter_dim % 128 != 0).
Test Plan
Verified on gfx950 with the no-quant (a16w16) legacy CK 2-stage path at inter_dim=704, sweeping token counts from 1 to 163840 (covering all block_m selections: 32 / 64 / 128).
Test Result
Submission Checklist